#!/usr/bin/env python3
# Copyright 2022 The ChromiumOS Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

# Contains custom presubmit checks implemented in Python.
#
# These are implemented as a CLI tool separate from tools/presubmit, because the presubmit
# framework needs to call a subprocess to execute checks.

from fnmatch import fnmatch
import os
import re
import json
from datetime import datetime
from pathlib import Path
from typing import Dict, Generator, List, cast

from impl.common import (
    cmd,
    cwd_context,
    run_commands,
)


def check_platform_independent(*files: str):
    """Checks the provided files to ensure they are free of platform dependent code.

    A line counts as platform dependent when it matches a `cfg` expression that
    mentions unix, linux, windows or android, or when it contains the text
    `target_os = `.

    Args:
        files: Paths of the files to scan. Each file is read as UTF-8.

    Raises:
        Exception: On the first offending line. The message names the file and
            the 1-based line number.
    """
    cfg_unix = "cfg.*unix"
    cfg_linux = "cfg.*linux"
    cfg_windows = "cfg.*windows"
    cfg_android = "cfg.*android"
    target_os = "target_os = "

    target_os_pattern = re.compile(
        "|".join((cfg_android, cfg_linux, cfg_unix, cfg_windows, target_os))
    )

    for file in files:
        # Close each file deterministically instead of relying on garbage collection.
        with open(file, encoding="utf8") as source:
            # Number lines from 1 so the reported location matches what editors show.
            for line_number, line in enumerate(source, start=1):
                if target_os_pattern.search(line):
                    raise Exception(
                        f"Found unexpected platform dependent code in {file}:{line_number}"
                    )


# fnmatch-style patterns for the files that are expected to use CRLF line endings.
# is_crlf_file() tests paths against these patterns.
CRLF_LINE_ENDING_FILES: List[str] = [
    "**.bat",
    "**.ps1",
    "e2e_tests/tests/goldens/backcompat_test_simple_lspci_win.txt",
    "tools/windows/build_test",
]


def is_crlf_file(file: str):
    "Returns True if `file` matches any pattern in CRLF_LINE_ENDING_FILES, otherwise False."
    return any(fnmatch(file, pattern) for pattern in CRLF_LINE_ENDING_FILES)


def check_line_endings(*files: str):
    """Checks line endings. Windows only files are using crlf. All others just lf.

    Runs `git ls-files --eol` on the given files and inspects the line endings
    reported for both the index and the working directory copy.

    Args:
        files: Paths forwarded to `git ls-files --eol`.

    Raises:
        Exception: If a file matched by is_crlf_file() is reported as anything
            other than "crlf" or "mixed", or if any other file is reported as
            "crlf" or "mixed".
    """

    def check_endings(file: str, endings: str):
        has_crlf = endings in ("crlf", "mixed")
        if is_crlf_file(file):
            if not has_crlf:
                raise Exception(f"{file} Expected crlf file endings. Found {endings}")
        elif has_crlf:
            raise Exception(f"{file} Expected lf file endings. Found {endings}")

    for line in cmd("git ls-files --eol", *files).lines():
        # Nothing to parse on a blank line.
        if not line.strip():
            continue
        # git separates the eol columns from the path with a TAB. Splitting the
        # whole line on whitespace would truncate paths that contain spaces.
        eol_info, tab, file = line.partition("\t")
        fields = eol_info.split()
        if not tab:
            # No TAB found: fall back to treating the last column as the path.
            file = fields[-1]
        # The first two columns look like "i/<eol>" and "w/<eol>"; drop the 2-char prefix.
        index_endings = fields[0][2:]
        wdir_endings = fields[1][2:]

        check_endings(file, index_endings)
        check_endings(file, wdir_endings)


def check_rust_lockfiles(*files: str):
    """Verifies that none of the Cargo.lock files require updates.

    Looks at ./Cargo.lock and every common/*/Cargo.lock. For each one,
    `cargo update --workspace --locked` is run from the lockfile's directory.
    The `files` arguments are not used.

    Raises:
        Exception: If the cargo command does not succeed for a lockfile.
    """
    root_lockfile = Path("Cargo.lock")
    # Collect every path up front, before the working directory is changed below.
    candidates = (root_lockfile, *Path("common").glob("*/Cargo.lock"))
    for lockfile in candidates:
        with cwd_context(lockfile.parent):
            up_to_date = cmd("cargo update --workspace --locked").success()
            if up_to_date:
                continue
            print(f"{lockfile} is not up-to-date.")
            print()
            print("You may need to rebase your changes and run `cargo update --workspace`")
            print("(or ./tools/run_tests) to ensure the Cargo.lock file is current.")
            raise Exception("Cargo.lock out of date")


# These crosvm features are currently not built upstream. Do not add to this list.
#
# check_rust_features() does not report a feature listed here, even when none of
# the all-* platform features enables it.
KNOWN_DISABLED_FEATURES = [
    "default-no-sandbox",
    "libvda",
    "seccomp_trace",
    "vulkano",
    "whpx",
]


def check_rust_features(*files: str):
    """Verifies that all cargo features are included in the list of features we compile upstream.

    Reads the feature table of the `crosvm` package from `cargo metadata` and
    collects everything reachable from the all-* platform features. Any feature
    that is neither reachable nor listed in KNOWN_DISABLED_FEATURES is reported.
    The `files` arguments are not used.

    Raises:
        Exception: Naming the features that are not enabled by any platform.
    """
    raw_metadata = cmd("cargo metadata --format-version=1").stdout()
    packages = json.loads(raw_metadata)["packages"]
    crosvm_package = next(pkg for pkg in packages if pkg["name"] == "crosvm")
    features = cast(Dict[str, List[str]], crosvm_package["features"])

    def collect_features(feature_name: str) -> Generator[str, None, None]:
        "Yields `feature_name` and every feature it enables, directly or transitively."
        yield feature_name
        for entry in features[feature_name]:
            if entry in features:
                yield from collect_features(entry)
                continue
            # optional crate is enabled through sub-feature of the crate.
            # e.g. protos optional crate/feature is enabled by protos/plugin
            crate_name = entry.split("/")[0]
            if crate_name not in features:
                continue
            if features[crate_name][0] == f"dep:{crate_name}":
                yield crate_name

    platform_roots = (
        "all-x86_64",
        "all-aarch64",
        "all-armhf",
        "all-mingw64",
        "all-msvc64",
        "all-riscv64",
        "all-android",
    )
    all_platform_features = set()
    for root in platform_roots:
        all_platform_features.update(collect_features(root))

    disabled_features = []
    for feature in features:
        if feature in all_platform_features:
            continue
        if feature in KNOWN_DISABLED_FEATURES:
            continue
        disabled_features.append(feature)

    if disabled_features:
        raise Exception(
            f"The features {', '.join(disabled_features)} are not enabled in upstream crosvm builds."
        )


# Pattern for the three-line license header. Each line may carry any prefix
# (e.g. a comment marker), the year may be a single year or a range, and the
# header must be followed by a blank line.
LICENSE_HEADER_RE = (
    r".*Copyright (?P<year>20[0-9]{2})(?:-20[0-9]{2})? The ChromiumOS Authors\n"
    r".*Use of this source code is governed by a BSD-style license that can be\n"
    r".*found in the LICENSE file\.\n"
    r"( *\*/\n)?"  # allow the end of a C-style comment before the blank line
    r"\n"
)

# Header lines, without comment prefix, inserted by check_copyright_header(fix=True).
# The year is taken from the clock when this module is loaded.
NEW_LICENSE_HEADER = [
    f"Copyright {datetime.now().year} The ChromiumOS Authors",
    "Use of this source code is governed by a BSD-style license that can be",
    "found in the LICENSE file.",
]


def new_licence_header(file_suffix: str):
    """Returns the license header as comment lines followed by one blank line.

    Args:
        file_suffix: File extension including the dot, or "" for none. The
            suffixes ".py", "", ".policy" and ".sh" get "#" comments; every
            other suffix gets "//" comments.
    """
    if file_suffix in (".py", "", ".policy", ".sh"):
        prefix = "#"
    else:
        prefix = "//"
    return "\n".join(f"{prefix} {line}" for line in NEW_LICENSE_HEADER) + "\n\n"


def _split_shebang(contents: str):
    "Splits `contents` into (shebang line including newline, rest); shebang is '' if absent."
    # `#![` opens a Rust inner attribute (e.g. `#![no_std]`), not an interpreter line.
    if not contents.startswith("#!") or contents.startswith("#!["):
        return "", contents
    shebang, _, rest = contents.partition("\n")
    return shebang + "\n", rest


def check_copyright_header(*files: str, fix: bool = False):
    """Checks copyright header. Can 'fix' them if needed by adding the header.

    Only the first 512 characters of each file are inspected. Files whose
    inspected part contains "generated by" are skipped.

    Args:
        files: Paths of the files to check. Files are read and written as UTF-8.
        fix: When True, add a new header to files that lack one instead of
            raising. A leading `#!` interpreter line is kept as the first line
            and the header is inserted right after it.

    Raises:
        Exception: If `fix` is False and a file has no valid header.
    """
    license_re = re.compile(LICENSE_HEADER_RE, re.MULTILINE)
    for file_path in (Path(f) for f in files):
        # Close the handle right away instead of leaking it until garbage collection.
        with file_path.open("r", encoding="utf8") as file:
            header = file.read(512)
        if license_re.search(header):
            continue
        # Generated files do not need a copyright header.
        if "generated by" in header:
            continue
        if not fix:
            raise Exception(f"Bad copyright header: {file_path}")
        print(f"Adding copyright header: {file_path}")
        shebang, contents = _split_shebang(file_path.read_text(encoding="utf8"))
        file_path.write_text(
            shebang + new_licence_header(file_path.suffix) + contents, encoding="utf8"
        )


def check_file_ends_with_newline(*files: str, fix: bool = False):
    """Checks if files end with a newline.

    Empty files are skipped. Only the last byte of each file is inspected.

    Args:
        files: Paths of the files to check.
        fix: When True, append a newline byte to files that lack one instead
            of raising. The existing bytes of the file are left untouched.

    Raises:
        Exception: If `fix` is False and a file does not end with a newline.
    """
    for file_path in (Path(f) for f in files):
        with file_path.open("rb") as file:
            # Skip empty files
            file.seek(0, os.SEEK_END)
            if file.tell() == 0:
                continue
            # Check last byte of the file. Compare raw bytes: the final byte may
            # be part of a multi-byte UTF-8 character and not decodable alone.
            file.seek(-1, os.SEEK_END)
            ends_with_newline = file.read(1) == b"\n"
        if ends_with_newline:
            continue
        if not fix:
            raise Exception(f"File does not end with a newline {file_path}")
        # Append in binary mode after the read handle is closed, so the rest of
        # the file is not re-encoded and its line endings are not rewritten.
        with file_path.open("ab") as file:
            file.write(b"\n")


if __name__ == "__main__":
    # Only when executed as a script: hand every check defined in this file to
    # run_commands (imported from impl.common).
    # NOTE(review): presumably run_commands exposes each function as a CLI
    # subcommand and forwards the command line arguments to it; confirm in impl.common.
    run_commands(
        check_file_ends_with_newline,
        check_copyright_header,
        check_rust_features,
        check_rust_lockfiles,
        check_line_endings,
        check_platform_independent,
    )
